{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "3d041443-ad35-430e-9a2c-6f0aca356d8e",
   "metadata": {},
   "source": [
    "Chapter 26\n",
    "# 可视化二元高斯分布PDF\n",
    "Book_1《编程不难》 | 鸢尾花书：从加减乘除到机器学习  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7001006d-25ad-469b-8368-986634b2f4a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from scipy.stats import multivariate_normal"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a53bcc0f-cd67-4cb9-b1d9-89cb12343e9d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Correlation coefficients to sweep: negative values, zero, positive values\n",
    "rho_array = [\n",
    "    -0.9, -0.7, -0.5, -0.3,\n",
    "    0,\n",
    "    0.3, 0.5, 0.7, 0.9,\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a337bfdf-e749-426b-8685-cd1d1974f9b2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Standard deviations of X and Y\n",
    "sigma_X, sigma_Y = 1, 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d977785-8665-414c-af83-ab6319743e30",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Expected values (means) of X and Y\n",
    "mu_X, mu_Y = 0, 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "57ee4f67-4a9d-4023-a516-b2dbffbc3672",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Half-width of the square domain [-width, width] x [-width, width]\n",
    "width = 4\n",
    "\n",
    "# 321 samples per axis, i.e. a grid spacing of 0.025\n",
    "X = np.linspace(-width, width, num=321)\n",
    "Y = np.linspace(-width, width, num=321)\n",
    "\n",
    "# Coordinate grids, then stacked depth-wise so that XXYY[i, j] = (x, y)\n",
    "XX, YY = np.meshgrid(X, Y)\n",
    "XXYY = np.dstack((XX, YY))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "08bdc8d1-89fb-4099-b7c9-e4ef197dfd21",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Surface plots: one 3-D panel per correlation coefficient in rho_array\n",
    "n_cols = 3\n",
    "# Ceiling division so any number of rho values fits on the grid\n",
    "# (the original hard-coded 3x3 grid fails beyond 9 values)\n",
    "n_rows = max(1, -(-len(rho_array) // n_cols))\n",
    "# Square panels: 8 x 8 inches for the default 3 x 3 layout\n",
    "fig = plt.figure(figsize = (8, 8 * n_rows / n_cols))\n",
    "\n",
    "# Centroid (mean vector); independent of rho, so it is built once\n",
    "mu = [mu_X, mu_Y]\n",
    "\n",
    "for idx, rho_idx in enumerate(rho_array):\n",
    "    # Covariance matrix; the off-diagonal term is rho * sigma_X * sigma_Y\n",
    "    cov_XY = sigma_X*sigma_Y*rho_idx\n",
    "    Sigma = [[sigma_X**2, cov_XY],\n",
    "             [cov_XY, sigma_Y**2]]\n",
    "    # Bivariate Gaussian distribution and its PDF evaluated on the grid\n",
    "    bi_norm = multivariate_normal(mu, Sigma)\n",
    "    f_X_Y_joint = bi_norm.pdf(XXYY)\n",
    "\n",
    "    ax = fig.add_subplot(n_rows, n_cols, idx+1, projection='3d')\n",
    "    # Sparse grey wireframe: every 10th grid row/column\n",
    "    ax.plot_wireframe(XX, YY, f_X_Y_joint,\n",
    "                      rstride=10, cstride=10,\n",
    "                      color = [0.3,0.3,0.3],\n",
    "                      linewidth = 0.25)\n",
    "\n",
    "    # 15 coloured contour levels drawn on the same axes\n",
    "    ax.contour(XX, YY, f_X_Y_joint, 15,\n",
    "               cmap = 'RdYlBu_r')\n",
    "\n",
    "    # Title each panel with its rho so the panels can be told apart\n",
    "    ax.set_title(rf'$\\rho = {rho_idx}$')\n",
    "    ax.set_xlabel('$x$'); ax.set_ylabel('$y$')\n",
    "    ax.set_zlabel('$f_{X,Y}(x,y)$')\n",
    "    ax.view_init(azim=-120, elev=30)\n",
    "    ax.set_proj_type('ortho')\n",
    "\n",
    "    ax.set_xlim(-width, width); ax.set_ylim(-width, width)\n",
    "    # Fit the z-axis to the value range of this panel's PDF\n",
    "    ax.set_zlim(f_X_Y_joint.min(), f_X_Y_joint.max())\n",
    "    # Optional: uncomment to hide the axes\n",
    "    # ax.axis('off')\n",
    "\n",
    "plt.tight_layout()\n",
    "# Optional: uncomment to export the figure as SVG\n",
    "# fig.savefig('bivariate_gaussian_surface.svg', format='svg')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b581de0-6326-43c9-a070-ca20c8fedf1b",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
